{# Parent image: the "base" flavor matching this CUDA version and OS release. IMAGE_NAME has no default here and is expected from the build arguments; the optional image_tag_suffix is appended only when present in the cuda config. #}
ARG IMAGE_NAME
FROM ${IMAGE_NAME}:{{ cuda.version.major }}.{{ cuda.version.minor }}-base-{{ cuda.os.distro }}{{ cuda.os.version }}{{ cuda.image_tag_suffix if "image_tag_suffix" in cuda }} AS base

{% if "x86_64" in cuda %}
{# x86_64 stage (selected when TARGETARCH is amd64): exports the pinned component versions as ENV so a later stage can install exactly those packages. #}
FROM base AS base-amd64

ENV NV_CUDA_LIB_VERSION={{ cuda.x86_64.components.libraries.version }}
ENV NV_NVTX_VERSION={{ cuda.x86_64.components.nvtx.version }}
ENV NV_LIBNPP_VERSION={{ cuda.x86_64.components.libnpp.version }}
ENV NV_LIBCUSPARSE_VERSION={{ cuda.x86_64.components.libcusparse.version }}

{# For CUDA 10.1 and 10.2 the cuBLAS package name becomes "libcublas10" (suffix is appended to "libcublas" below). #}
{# NOTE(review): suffix is assigned only for 10.1/10.2; for any other version outside the 9.x/10.0 list it keeps whatever value it already had (possibly undefined) - confirm this is intended. #}
{% if "libcublas" in cuda.x86_64.components %}
    {% if cuda.version.major_minor in ["10.1", "10.2"] %}
        {% set suffix = "10" %}
    {% endif %}
{% endif %}

{% if cuda.version.major_minor in ["9.0", "9.1", "9.2", "10.0"] %}
    {# [> This should never change at this point <] #}
    {# Older releases use the cuda-cublas-MAJOR-MINOR package name. Hard-coded pins that used to live here (now disabled): 9.0 -> 9.0.176.4-1, 9.1 -> 9.1.85.3-1, 9.2 -> 9.2.148.1-1. #}
ENV NV_LIBCUBLAS_PACKAGE_NAME=cuda-cublas-{{ cuda.version.major }}-{{ cuda.version.minor }}
{% else %}
ENV NV_LIBCUBLAS_PACKAGE_NAME=libcublas{{ suffix }}
{% endif %}

{# NV_LIBCUBLAS_PACKAGE is the apt "name=version" spec built from the two values above. #}
ENV NV_LIBCUBLAS_VERSION={{ cuda.x86_64.components.libcublas.version }}
ENV NV_LIBCUBLAS_PACKAGE="${NV_LIBCUBLAS_PACKAGE_NAME}=${NV_LIBCUBLAS_VERSION}"

{# [> L4T uses older packaging name format for libcublas <] #}
{# A 10.2-specific override to "libcublas10" used to be sketched here; it is disabled. #}

{# nccl_package records whether an NCCL pin was configured for this architecture (1) or not (0). #}
{% if "libnccl2" in cuda.x86_64.components and cuda.x86_64.components.libnccl2 %}
    {% set nccl_package = 1 %}
ENV NV_LIBNCCL_PACKAGE_NAME="libnccl2"
ENV NV_LIBNCCL_PACKAGE_VERSION={{ cuda.x86_64.components.libnccl2.version }}
{# NCCL_VERSION drops the last two characters of the package version (presumably the "-1" revision suffix - confirm against the config data). #}
ENV NCCL_VERSION={{ cuda.x86_64.components.libnccl2.version[:-2] }}
ENV NV_LIBNCCL_PACKAGE="${NV_LIBNCCL_PACKAGE_NAME}=${NV_LIBNCCL_PACKAGE_VERSION}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}"
{% else %}
    {% set nccl_package = 0 %}
{% endif -%}

{% endif %}
{% if "ppc64le" in cuda %}
{# ppc64le stage (selected when TARGETARCH is ppc64le): exports the pinned component versions as ENV so a later stage can install exactly those packages. #}
FROM base AS base-ppc64le

{# Every version in this stage is read from the ppc64le component set, including the cuda-libraries pin. #}
ENV NV_CUDA_LIB_VERSION={{ cuda.ppc64le.components.libraries.version }}
ENV NV_NVTX_VERSION={{ cuda.ppc64le.components.nvtx.version }}
ENV NV_LIBNPP_VERSION={{ cuda.ppc64le.components.libnpp.version }}
ENV NV_LIBCUSPARSE_VERSION={{ cuda.ppc64le.components.libcusparse.version }}

{# For CUDA 10.1 and 10.2 the cuBLAS package name becomes "libcublas10" (suffix is appended to "libcublas" below). #}
{# NOTE(review): suffix is assigned only for 10.1/10.2; for any other version outside the 9.x/10.0 list it keeps whatever value it already had (possibly undefined) - confirm this is intended. #}
{% if "libcublas" in cuda.ppc64le.components %}
    {% if cuda.version.major_minor in ["10.1", "10.2"] %}
        {% set suffix = "10" %}
    {% endif %}
{% endif %}

{% if cuda.version.major_minor in ["9.0", "9.1", "9.2", "10.0"] %}
    {# [> This should never change at this point <] #}
    {# Older releases use the cuda-cublas-MAJOR-MINOR package name. Hard-coded pins that used to live here (now disabled): 9.0 -> 9.0.176.4-1, 9.1 -> 9.1.85.3-1, 9.2 -> 9.2.148.1-1. #}
ENV NV_LIBCUBLAS_PACKAGE_NAME=cuda-cublas-{{ cuda.version.major }}-{{ cuda.version.minor }}
{% else %}
ENV NV_LIBCUBLAS_PACKAGE_NAME=libcublas{{ suffix }}
{% endif %}

{# NV_LIBCUBLAS_PACKAGE is the apt "name=version" spec built from the two values above. #}
ENV NV_LIBCUBLAS_VERSION={{ cuda.ppc64le.components.libcublas.version }}
ENV NV_LIBCUBLAS_PACKAGE="${NV_LIBCUBLAS_PACKAGE_NAME}=${NV_LIBCUBLAS_VERSION}"

{# nccl_package records whether an NCCL pin was configured for this architecture (1) or not (0). #}
{% if "libnccl2" in cuda.ppc64le.components and cuda.ppc64le.components.libnccl2 %}
    {% set nccl_package = 1 %}
ENV NV_LIBNCCL_PACKAGE_NAME="libnccl2"
ENV NV_LIBNCCL_PACKAGE_VERSION={{ cuda.ppc64le.components.libnccl2.version }}
{# NCCL_VERSION drops the last two characters of the package version (presumably the "-1" revision suffix - confirm against the config data). #}
ENV NCCL_VERSION={{ cuda.ppc64le.components.libnccl2.version[:-2] }}
ENV NV_LIBNCCL_PACKAGE="${NV_LIBNCCL_PACKAGE_NAME}=${NV_LIBNCCL_PACKAGE_VERSION}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}"
{% else %}
    {% set nccl_package = 0 %}
{% endif -%}
{% endif -%}

# Final stage: continue from the per-architecture stage that matches the build target.
FROM base-${TARGETARCH}

# Re-declared inside the stage so TARGETARCH is also visible to the instructions below.
ARG TARGETARCH

LABEL maintainer="NVIDIA CORPORATION <cudatools@nvidia.com>"

{# NCCL is added to the install list when at least one configured architecture pins it. The decision is computed here from both architectures so it does not depend on which architecture block was rendered last. On an architecture without a pin, NV_LIBNCCL_PACKAGE is unset and the unquoted shell expansion contributes no argument. #}
{% set nccl_amd64 = "x86_64" in cuda and "libnccl2" in cuda.x86_64.components and cuda.x86_64.components.libnccl2 %}
{% set nccl_ppc64le = "ppc64le" in cuda and "libnccl2" in cuda.ppc64le.components and cuda.ppc64le.components.libnccl2 %}
{# NOTE(review): cusparse always uses the cuda-cusparse-MAJOR-MINOR name while NPP switches to libnpp-MAJOR-MINOR for major > 10 - confirm the cusparse package name for CUDA 11 and later. #}
# Install the pinned CUDA runtime libraries; versions come from the ENV values of the selected base stage.
RUN apt-get update && apt-get install -y --no-install-recommends \
    cuda-libraries-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_LIB_VERSION} \
    {% if (cuda.arch != "arm64" and cuda.version.major | int <= 10) or cuda.version.build == 89 %}
    cuda-npp-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBNPP_VERSION} \
    {% elif cuda.version.major | int > 10 %}
    libnpp-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBNPP_VERSION} \
    {% endif %}
    {% if (cuda.version.major_minor == "9.2" ) or cuda.version.major | int >= 10 %}
    cuda-nvtx-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_NVTX_VERSION} \
    {% endif %}
    cuda-cusparse-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCUSPARSE_VERSION} \
    ${NV_LIBCUBLAS_PACKAGE} \
    {% if nccl_amd64 or nccl_ppc64le %}
    ${NV_LIBNCCL_PACKAGE} \
    {% endif %}
    && rm -rf /var/lib/apt/lists/*

# Keep apt from auto upgrading the cublas and nccl packages. See https://gitlab.com/nvidia/container-images/cuda/-/issues/88
# NV_LIBNCCL_PACKAGE_NAME may be unset when NCCL is not configured; it then expands to nothing and only cublas is held.
RUN apt-mark hold ${NV_LIBNCCL_PACKAGE_NAME} ${NV_LIBCUBLAS_PACKAGE_NAME}
